# ------------------------------------------------------------------------------
# Plots test rate and yield of testing by "stent-hat" (percentile of
# predicted stent/CABG risk)
# Updated by: Cassidy Shubatt <cshubatt@gmail.com>
# To run: bash 01_plot_yield.sh
# ------------------------------------------------------------------------------

# Libraries --------------------------------------------------------------------
library(here)
library(yaml)
library(data.table)
library(tidyverse)
library(glue)
library(testit) # assert()
library(ggplot2)
library(ggthemes)
library(viridis)
library(Metrics) # auc()

# Project modules: `u` (lib/util.R) supplies the population/mean helpers and
# `a` (lib/aesthetics.R) the palette, plotting and font helpers used below.
u <- modules::use(here("lib", "util.R"))
a <- modules::use(here("lib", "aesthetics.R"))
# NOTE(review): presumably registers the bundled Optima font so that
# `family = "Optima"` in plot annotations resolves — confirm in aesthetics.R.
a$get_font("Optima", here("lib", "optima.ttf"))

# Output directory for this script's PNGs. `recursive = TRUE` also creates any
# missing parent directories (e.g. temp/ on a fresh checkout); without it,
# dir.create() only warns and leaves the directory missing, so the later
# ggsave() calls fail.
temp <- here("code", "03_score_diagnostics", "temp", "01_plot_yield")
if (!dir.exists(temp)) {
  dir.create(temp, recursive = TRUE)
}

# Helpers ----------------------------------------------------------------------
# Add axis labels (and optionally a title, subtitle and caption) to a plot.
#
# Called row-wise via pmap() over the plot table, so every column of that
# table arrives as a named argument. Columns this function does not use
# (e.g. `palette`) are absorbed by `...`.
#
# Args:
#   gg: ggplot object to annotate.
#   outcome: outcome column name, mapped to a human-readable y-axis label
#     (unrecognized names are used verbatim).
#   x_var: x-axis column name, mapped to a human-readable x-axis label
#     (unrecognized names are used verbatim).
#   population: population description interpolated into the subtitle.
#   overall_mean: overall outcome mean. It is drawn as a dashed reference line
#     with a label when `outcome` is in `overall_outcomes_incl`.
#   mean_outcomes: summary table whose `n` column is summed for the caption.
#   incl_titles: if TRUE, also add the title, subtitle and caption.
#   ...: ignored; absorbs extra pmap() columns.
#   overall_outcomes_incl: outcomes that get the overall-mean reference line.
#     Defaults to none, which matches the previous hard-coded behavior.
#
# Returns the annotated ggplot object.
add_labels <- function(gg, outcome, x_var, population, overall_mean,
                       mean_outcomes, incl_titles = FALSE, ...,
                       overall_outcomes_incl = character(0)) {
  xlabel <- case_when(
    x_var == "tile_stent_or_cabg_010_tested" ~ "Percentile of Predicted Risk",
    x_var == "tile_stent_ecg_010_tested" ~ "Percentile of Predicted Risk (ECG)",
    TRUE ~ x_var
  )
  # Fall back to the raw outcome name (as xlabel does for x_var) so that an
  # unmapped outcome no longer yields an "NA" axis label and title.
  ylabel <- case_when(
    outcome == "stent_010_day" ~ "Percutaneous Cardiac Intervention (Stent) Rate",
    outcome == "cabg_010_day" ~ "Coronary Artery Bypass Graft (Open-heart surgery) Rate",
    outcome == "stent_or_cabg_010_day" ~ "Yield of Testing",
    outcome == "test_010_day" ~ "Test Rate",
    outcome == "stress_010_day" ~ "Stress Test Rate",
    outcome == "cath_010_day" ~ "Catheterization Rate",
    TRUE ~ outcome
  )

  gg <- gg + labs(x = xlabel, y = ylabel)

  if (incl_titles) {
    title <- glue("{ylabel} by Predicted Risk")
    subtitle <- glue("Holdout cohort ({population}), excluding patients with chronic illness history, over 80, untested AMI day of")
    n <- sum(mean_outcomes$n)
    caption <- glue("N = {n}")
    gg <- gg + labs(title = title, subtitle = subtitle, caption = caption)
  }

  # Optional dashed reference line at the overall mean, with a text label.
  if (outcome %in% overall_outcomes_incl) {
    overall_lab <- glue("Overall {ylabel}: {overall_mean}    ")
    gg <- gg +
      geom_hline(
        yintercept = overall_mean, alpha = 1, color = "black",
        linetype = "dashed"
      ) +
      annotate(
        "label", x = 2, y = overall_mean, label = overall_lab, hjust = 0,
        alpha = 1, family = "Optima", size = 16
      )
  }

  gg
}

# Build the output PNG path for one outcome / x-variable plot.
#
# Risk-tile x-variables (names containing "tile_stent_or_cabg") are shortened
# to "tile_" plus their number, zero-padded to three characters. For example,
# "tile_stent_or_cabg_010_tested" becomes "tile_010". Other x_var values are
# used verbatim.
#
# Args:
#   outcome: outcome column name (used as the filename prefix).
#   x_var: x-axis column name.
#   ...: ignored; absorbs extra pmap() columns.
#   out_dir: directory for the file. Defaults to the script-level `temp`.
#
# Returns the full file path(s) as a character vector.
get_filename <- function(outcome, x_var, ..., out_dir = temp) {
  is_tile <- grepl("tile_stent_or_cabg", x_var)
  x_var_lab <- x_var
  # Parse numbers only for matching names, so non-tile names never trigger
  # parse_number() "parsing failure" warnings.
  x_var_lab[is_tile] <- glue(
    "tile_{str_pad(parse_number(x_var[is_tile]), 3, pad = '0')}"
  )
  file.path(out_dir, glue("{outcome}__by__{x_var_lab}.png"))
}

# Load Data --------------------------------------------------------------------
message("Loading data...")
paths <- read_yaml(here("lib", "filepaths.yml"))
# NOTE(review): `overnight_lab` is not referenced directly below. It is
# presumably interpolated by glue() as `{overnight_lab}` in the cohort path
# template from filepaths.yml, so confirm before removing it.
overnight_lab <- ""
# Read the test cohort, then drop rows flagged for exclusion.
cohort <- filter(readRDS(glue(paths$analysis$test_cohort)), !exclude)

# Outcome Config ---------------------------------------------------------------
message("Preparing config table...")
# One row per plot: every outcome crossed with every x-axis variable.
config <- crossing(
  outcome = c("test_010_day", "stent_or_cabg_010_day"),
  x_var = "tile_stent_or_cabg_010_tested"
)
# Reorder the discrete palette (1, 3, 2, 4, 5) and recycle it so that each
# config row gets a color.
color_vec <- a$disc_palette[c(1, 3, 2, 4, 5)] %>%
  rep(length.out = nrow(config))
config[["palette"]] <- color_vec

# Means and Plots --------------------------------------------------------------
message("Getting mean outcomes and plotting...")
# Build the plot table one column at a time. Each pmap() call walks the rows
# of the table built so far and passes every existing column as a named
# argument, so the helpers must accept (and ignore) extra columns via `...`.
plots <- config
plots$population <- unlist(pmap(plots, u$get_population))
plots$overall_mean <- unlist(pmap(plots, u$get_overall_mean, df = cohort))
plots$mean_outcomes <- pmap(plots, u$get_mean_outcomes, df = cohort)
plots$gg <- pmap(plots, a$tile_plot)
plots$plot <- pmap(plots, add_labels)
# To include titles in the saved plots, use this line instead:
# plots$plot <- pmap(plots, add_labels, incl_titles = TRUE)
plots$filename <- unlist(pmap(plots, get_filename))

# Save -------------------------------------------------------------------------
message("Saving...")
# ggsave() is called only for its side effect of writing each PNG, so use
# pwalk() rather than pmap() and keep no result. `plot` and `filename` are
# passed to ggsave() by name.
plots %>%
  select(plot, filename) %>%
  pwalk(ggsave, width = 10, height = 7, units = "in")

message("Done.")
